{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy HPO with Custom Algorithms\n",
    "\n",
    "In this chapter, we will focus on deploying HPO with custom algorithms, emphasizing the details rather than the overall workflow. A brief introduction to HPO deployment is provided in the previous chapter, [Efficient HPO with EvoX](#/guide/user/3-hpo), and prior reading is highly recommended."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making Algorithms Parallelizable\n",
    "\n",
    "Since we need to transform the inner algorithm into the problem, it's crucial that the inner algorithm is parallelizable. Therefore, some modifications to the algorithm may be necessary.\n",
    "\n",
    "To ensure the function is JIT-compilable, it must meet the conditions outlined in [JIT components](#/guide/developer/2-jit-able). In addition to these requirements, the algorithm must also satisfy the following two constraints:\n",
    "\n",
    "1. The algorithm should have no methods with in-place operations on the attributes of the algorithm itself.\n",
    "\n",
    "```python\n",
    "class ExampleAlgorithm(Algorithm):\n",
    "    def __init__(self,...): \n",
    "        self.pop = torch.rand(10,10) #attribute of the algorithm itself\n",
    "        pass\n",
    "\n",
    "    def step_in_place(self): # method with in-place operations\n",
    "        self.pop.copy_(pop)\n",
    "        pass\n",
    "\n",
    "    def step_out_of_place(self): # method without in-place operations\n",
    "        self.pop = pop\n",
    "        pass\n",
    "```\n",
    "\n",
    "2. The code logic does not rely on python control flow.\n",
    "\n",
    "```python\n",
    "class ExampleAlgorithm(Algorithm):\n",
    "    def __init__(self,...): \n",
    "        self.pop = rand(10,10) #attribute of the algotirhm itself\n",
    "        pass\n",
    "\n",
    "    def plus(self, y):\n",
    "        self.pop += y\n",
    "        pass\n",
    "\n",
    "    def minus(self, y):\n",
    "        self.pop -= y\n",
    "        pass      \n",
    "\n",
    "    def step_with_python_control_flow(self, y): # function with python control flow\n",
    "        x = rand()\n",
    "        if x>0.5:\n",
    "            self.plus(y)\n",
    "        else:\n",
    "            self.minus(y)\n",
    "        pass\n",
    "\n",
    "    def step_without_python_control_flow(self, y): # function without python control flow\n",
    "        x = rand()\n",
    "        cond = x > 0.5\n",
    "        _if_else_ = TracingCond(self.plus, self.minus)\n",
    "        _if_else_.cond(cond,y)\n",
    "        self.pop = pop\n",
    "        pass\n",
    "```\n",
    "\n",
    "In EvoX, we can easily make the algorithm parallelizable by the [`@trace_impl`](#trace_impl) decorator. \n",
    "\n",
    "The parameter of this decorator is a non-parallelizable function, and the decorated function is a rewrite of the original function. Detailed introduction of [`@trace_impl`](#trace_impl) can be found in [JIT Components](#/guide/developer/2-jit-able).\n",
    "\n",
    "Under this mechanism, we can retain the original function for use outside HPO tasks while enabling efficient computation within HPO tasks. Moreover, this modification is highly convenient."
   ]
  },
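  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two constraints can also be illustrated without EvoX. The following is a minimal sketch in plain PyTorch, assuming `torch.vmap` and `torch.where` as stand-ins for the batching applied to the inner algorithm and for [`TracingCond`](#TracingCond). The function names are hypothetical and are not part of the EvoX API.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def step_python_branch(pop, y, x):\n",
    "    # Python `if` on tensor data: fine for one instance, not batchable.\n",
    "    if x > 0.5:\n",
    "        return pop + y\n",
    "    return pop - y\n",
    "\n",
    "\n",
    "def step_branchless(pop, y, x):\n",
    "    # Compute both branches and select one; no Python control flow,\n",
    "    # no in-place update of the inputs.\n",
    "    return torch.where(x > 0.5, pop + y, pop - y)\n",
    "\n",
    "\n",
    "pops = torch.zeros(3, 2)  # three instances, two variables each\n",
    "ys = torch.ones(3, 2)\n",
    "xs = torch.tensor([0.1, 0.9, 0.7])  # one branch selector per instance\n",
    "\n",
    "# Batched evaluation of all instances at once.\n",
    "batched = torch.vmap(step_branchless)(pops, ys, xs)\n",
    "# Reference result: run the branching version instance by instance.\n",
    "looped = torch.stack([step_python_branch(p, y, x) for p, y, x in zip(pops, ys, xs)])\n",
    "assert torch.equal(batched, looped)\n",
    "```\n",
    "\n",
    "Both branches are computed for every instance and the result is selected afterwards, which is the usual cost of removing Python control flow from batched code."
   ]
  },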
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utilizing the HPOMonitor\n",
    "\n",
    "In the HPO task, we should use the [`HPOMonitor`](#HPOMonitor) to track the metrics of each inner algorithm. The [`HPOMonitor`](#HPOMonitor) adds only one method, `tell_fitness`, compared to the standard [`monitor`](#Monitor). This addition is designed to offer greater flexibility in evaluating metrics, as HPO tasks often involve multi-dimensional and complex metrics.\n",
    "\n",
    "Users only need to create a subclass of [`HPOMonitor`](#HPOMonitor) and override the `tell_fitness` method to define custom evaluation metrics.\n",
    "\n",
    "We also provide a simple [`HPOFitnessMonitor`](#HPOFitnessMonitor), which supports calculating the 'IGD' and 'HV' metrics for multi-objective problems, and the minimum value for single-objective problems."
   ]
  },
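  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `tell_fitness` concrete, here is a minimal sketch of a custom metric in plain PyTorch. The class below is a hypothetical stand-in and does not inherit from [`HPOMonitor`](#HPOMonitor); the `record` hook and the choice of the mean as the metric are assumptions made for illustration. Check the EvoX API reference for the actual hooks and signatures before subclassing.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class MeanFitnessTracker:\n",
    "    # Hypothetical stand-in for an HPOMonitor subclass (not the EvoX class).\n",
    "    def __init__(self):\n",
    "        self.latest = torch.tensor(float('inf'))\n",
    "\n",
    "    def record(self, fitness):\n",
    "        # Hypothetical hook: rebind the attribute instead of updating it in place.\n",
    "        self.latest = fitness.mean()\n",
    "\n",
    "    def tell_fitness(self):\n",
    "        # The single value reported to the outer algorithm for this instance.\n",
    "        return self.latest\n",
    "\n",
    "\n",
    "tracker = MeanFitnessTracker()\n",
    "for fitness in (torch.tensor([4.0, 2.0]), torch.tensor([1.0, 3.0])):\n",
    "    tracker.record(fitness)\n",
    "final_metric = tracker.tell_fitness()  # mean fitness of the last generation\n",
    "```\n",
    "\n",
    "Reporting the mean of the last population instead of the best value found so far is just one possible design; any scalar that can be computed from the recorded fitness values could be returned in the same way."
   ]
  },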
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple example\n",
    "\n",
    "Here, we'll demonstrate a simple example of how to use HPO with EvoX. We will use the [`PSO`](#PSO) algorithm to search for the optimal hyper-parameters of a basic algorithm to solve the sphere problem.\n",
    "\n",
    "First, let's import the necessary modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from evox.algorithms.pso_variants.pso import PSO\n",
    "from evox.core import Algorithm, Mutable, Parameter, Problem, trace_impl\n",
    "from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper\n",
    "from evox.utils import TracingCond\n",
    "from evox.workflows import EvalMonitor, StdWorkflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define an simple sphere problem. Note that this has no difference from the common [`problems`](#evox.problems)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Sphere(Problem):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def evaluate(self, x: torch.Tensor):\n",
    "        return (x * x).sum(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define the algorithm. The original `step` function is non-parallelizable, so we rewrite it using the [`@trace_impl`](#trace_impl) decorator to make it parallelizable. Specifically, we modify in-place operations and adjust the Python control flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleAlgorithm(Algorithm):\n",
    "    def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor):\n",
    "        super().__init__()\n",
    "        assert lb.ndim == 1 and ub.ndim == 1, f\"Lower and upper bounds shall have ndim of 1, got {lb.ndim} and {ub.ndim}\"\n",
    "        assert lb.shape == ub.shape, f\"Lower and upper bounds shall have same shape, got {lb.ndim} and {ub.ndim}\"\n",
    "        self.pop_size = pop_size\n",
    "        self.hp = Parameter([1.0, 2.0, 3.0, 4.0])  # the hyperparameters to be optimized\n",
    "        self.lb = lb\n",
    "        self.ub = ub\n",
    "        self.dim = lb.shape[0]\n",
    "        self.pop = Mutable(torch.empty(self.pop_size, lb.shape[0], dtype=lb.dtype, device=lb.device))\n",
    "        self.fit = Mutable(torch.empty(self.pop_size, dtype=lb.dtype, device=lb.device))\n",
    "\n",
    "    def strategy_1(self, pop):  # one update strategy\n",
    "        pop = pop * (self.hp[0] + self.hp[1])\n",
    "        self.pop = pop\n",
    "\n",
    "    def strategy_2(self, pop):  #  the other update strategy\n",
    "        pop = pop * (self.hp[2] + self.hp[3])\n",
    "        self.pop = pop\n",
    "\n",
    "    def step(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)  # simply random sampling\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        control_number = torch.rand()\n",
    "        if control_number < 0.5:  # conditional control\n",
    "            pop = self.strategy_1(pop)\n",
    "        else:\n",
    "            pop = self.strategy_2(pop)\n",
    "        self.pop.copy_(pop)  # in-place update\n",
    "        self.fit.copy_(self.evaluate(pop))\n",
    "\n",
    "    # (using class methods for control flow)\n",
    "    @trace_impl(step)  # rewrite the step function to support vmap\n",
    "    def trace_step(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        pop = pop * self.hp[0]\n",
    "        control_number = torch.rand()\n",
    "        cond = control_number < 0.5\n",
    "        # Deal with the conditional control flow equivalent in tracing\n",
    "        branches = (self.strategy_1, self.strategy_2)\n",
    "        state, names = self.prepare_control_flow(*branches)\n",
    "        _if_else_ = TracingCond(*branches)\n",
    "        state = _if_else_.cond(state, cond, pop)\n",
    "        self.after_control_flow(state, *names)\n",
    "        # Evaluate\n",
    "        self.fit = self.evaluate(pop)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To handle the Python control flow, we use [`TracingCond`](#TracingCond), [`TracingWhile`](#TracingWhile) and [`TracingSwitch`](#TracingSwitch). Since, in tracing mode, variables outside the method may be incorrectly interpreted as static variables, we need to use state to track them. A detailed introduction to [`TracingCond`](#TracingCond), [`TracingWhile`](#TracingWhile) and [`TracingSwitch`](#TracingSwitch) can be found in [JIT Components](#/guide/developer/2-jit-able). Below, we provide two equivalent implementations for the `trace_step` method.\n",
    "\n",
    "```python\n",
    "# Equivalent to the following code (Local function style)\n",
    "\n",
    "    @trace_impl(step)  # rewrite the step function to support vmap\n",
    "    def trace_step(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        pop = pop * self.hp[0]\n",
    "        control_number = torch.rand()\n",
    "        cond = control_number < 0.5\n",
    "        # Deal with the conditional control flow equivalent in tracing\n",
    "        branches = (lambda: pop * self.hp[1], lambda: pop * self.hp[2])\n",
    "        state, names = self.prepare_control_flow(*branches)\n",
    "        _if_else_ = TracingCond(*branches, stateful_functions=True)\n",
    "        state, pop = _if_else_.cond(state, cond)\n",
    "        self.after_control_flow(state, *names)\n",
    "        # Evaluate\n",
    "        self.pop = pop\n",
    "        self.fit = self.evaluate(pop)\n",
    "\n",
    "\n",
    "# Equivalent to the following code (Pure function style)\n",
    "\n",
    "    @trace_impl(step)  # rewrite the step function to support vmap\n",
    "    def trace_step(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        pop = pop * self.hp[0]\n",
    "        control_number = torch.rand()\n",
    "        cond = control_number < 0.5\n",
    "        # Deal with the conditional control flow equivalent in tracing\n",
    "        branches = (lambda p, hp: p * hp[1], lambda p, hp: p * hp[2])\n",
    "        _if_else_ = TracingCond(*branches, stateful_functions=False) # defaults to False for no member function\n",
    "        pop = _if_else_.cond(cond, pop, self.hp)\n",
    "        # Evaluate\n",
    "        self.pop = pop\n",
    "        self.fit = self.evaluate(pop)\n",
    "```\n",
    "\n",
    "Next, we can use the [`StdWorkflow`](#StdWorkflow) to wrap the [`problem`](#evox.problems), [`algorithm`](#evox.algorithms) and [`monitor`](#Monitor). Then we use the [`HPOProblemWrapper`](#HPOProblemWrapper) to transform the [`StdWorkflow`](#StdWorkflow) to HPO problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_default_device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "inner_algo = ExampleAlgorithm(10, -10 * torch.ones(8), 10 * torch.ones(8))\n",
    "inner_prob = Sphere()\n",
    "inner_monitor = HPOFitnessMonitor()\n",
    "inner_monitor.setup()\n",
    "inner_workflow = StdWorkflow()\n",
    "inner_workflow.setup(inner_algo, inner_prob, monitor=inner_monitor)\n",
    "# Transform the inner workflow to an HPO problem\n",
    "hpo_prob = HPOProblemWrapper(iterations=9, num_instances=7, workflow=inner_workflow, copy_init_state=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can test whether the [`HPOProblemWrapper`](#HPOProblemWrapper) correctly recognizes the hyper-parameters we defined. Since we have made no modifications to the hyper-parameters for the 7 instances, they should be identical across all instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init params:\n",
      " {'self.algorithm.hp': Parameter containing:\n",
      "tensor([[1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.],\n",
      "        [1., 2., 3., 4.]], device='cuda:0')}\n"
     ]
    }
   ],
   "source": [
    "params = hpo_prob.get_init_params()\n",
    "print(\"init params:\\n\", params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also specify our own set of hyperparameter values. Note that the number of hyperparameter sets must match the number of instances in the [`HPOProblemWrapper`](#HPOProblemWrapper). The custom hyper-parameters should be provided as a dictionary whose values are wrapped in the [`Parameter`](#Parameter)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params:\n",
      " {'self.algorithm.hp': Parameter containing:\n",
      "tensor([[0.8108, 0.7703, 0.8577, 0.0708],\n",
      "        [0.3465, 0.7551, 0.0136, 0.5634],\n",
      "        [0.9978, 0.8935, 0.7606, 0.9789],\n",
      "        [0.9837, 0.4787, 0.5919, 0.2196],\n",
      "        [0.9336, 0.8979, 0.8039, 0.0677],\n",
      "        [0.7770, 0.4149, 0.8965, 0.6570],\n",
      "        [0.1422, 0.5341, 0.6108, 0.5978]], device='cuda:0')} \n",
      "\n",
      "result:\n",
      " tensor([77.0704, 15.8463, 21.6154, 40.8018, 43.6397, 55.0446,  2.4755],\n",
      "       device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "params = hpo_prob.get_init_params()\n",
    "# since we have 7 instances, we need to pass 7 sets of hyperparameters\n",
    "params[\"self.algorithm.hp\"] = torch.nn.Parameter(torch.rand(7, 4), requires_grad=False)\n",
    "result = hpo_prob.evaluate(params)\n",
    "print(\"params:\\n\", params, \"\\n\")\n",
    "print(\"result:\\n\", result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we use the [`PSO`](#PSO) algorithm to optimize the hyper-parameters of `ExampleAlgorithm`. Note that the population size of the [`PSO`](#PSO) must match the number of instances; otherwise, unexpected errors may occur. In this case, we need to transform the solution in the outer workflow, as the [`HPOProblemWrapper`](#HPOProblemWrapper) requires a dictionary as input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params:\n",
      " tensor([[0.0031, 0.4910, 1.8519, 1.2221]], device='cuda:0') \n",
      "\n",
      "result:\n",
      " tensor([0.0012], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "class solution_transform(torch.nn.Module):\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        return {\"self.algorithm.hp\": x}\n",
    "\n",
    "\n",
    "outer_algo = PSO(7, -3 * torch.ones(4), 3 * torch.ones(4))\n",
    "monitor = EvalMonitor(full_sol_history=False)\n",
    "outer_workflow = StdWorkflow()\n",
    "outer_workflow.setup(outer_algo, hpo_prob, monitor=monitor, solution_transform=solution_transform())\n",
    "outer_workflow.init_step()\n",
    "for _ in range(20):\n",
    "    outer_workflow.step()\n",
    "monitor = outer_workflow.get_submodule(\"monitor\")\n",
    "print(\"params:\\n\", monitor.topk_solutions, \"\\n\")\n",
    "print(\"result:\\n\", monitor.topk_fitness)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
